39d8d7
@@ -41,7 +41,7 @@
 import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveIn;
 import org.apache.hadoop.hive.ql.plan.ColStatistics;
-import org.apache.hadoop.hive.ql.plan.ColStatistics.Range;
+import org.apache.hadoop.hive.ql.stats.StatsUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -56,8 +56,10 @@
  * we can infer that the predicate will evaluate to false if the max
  * value for column a is 4.
  *
- * Currently we support the simplification of =, >=, <=, >, <, and
- * IN operations.
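+ * Currently we support the simplification of:
+ *  - comparisons {@code =, >=, <=, >, <}, using the column's min/max values
+ *  - {@code IN}
+ *  - {@code IS NULL} / {@code IS NOT NULL}, using the column's null count and the row count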
  */
 public class HiveReduceExpressionsWithStatsRule extends RelOptRule {
 
@@ -243,9 +245,28 @@ public RexNode visitCall(RexCall call) {
           }
           return rexBuilder.makeCall(HiveIn.INSTANCE, newOperands);
         }
-
         // We cannot apply the reduction
         return call;
+      } else if (call.getOperator().getKind() == SqlKind.IS_NULL || call.getOperator().getKind() == SqlKind.IS_NOT_NULL) {
+        SqlKind kind = call.getOperator().getKind();
+        // Only a direct column reference can be traced back to table statistics.
+        if (call.operands.get(0) instanceof RexInputRef) {
+          RexInputRef ref = (RexInputRef) call.operands.get(0);
+          // Both helpers return null unless the corresponding stats are marked up to date.
+          ColStatistics stat = extractColStats(ref);
+          Long rowCount = extractRowCount(ref);
+          if (stat != null && rowCount != null) {
+            long numNulls = stat.getNumNulls();
+            // numNulls == 0: no NULLs; numNulls == rowCount (> 0 only): every value is NULL.
+            if (numNulls == 0 || (rowCount > 0 && numNulls == rowCount)) {
+              boolean allNulls = numNulls != 0;
+              if (kind == SqlKind.IS_NULL) {
+                return rexBuilder.makeLiteral(allNulls);
+              }
+              return rexBuilder.makeLiteral(!allNulls);
+            }
+          }
+        }
       }
 
       // If we did not reduce, check the children nodes
@@ -257,8 +278,18 @@ public RexNode visitCall(RexCall call) {
     }
 
     private Pair<Number,Number> extractMaxMin(RexInputRef ref) {
+      ColStatistics colStats = extractColStats(ref); // null when stats are missing or not up to date
       Number max = null;
       Number min = null;
+      ColStatistics.Range range = (colStats == null) ? null : colStats.getRange();
+      if (range != null) {
+        max = range.maxValue;
+        min = range.minValue;
+      }
+      return Pair.of(max, min);
+    }
+
+    private ColStatistics extractColStats(RexInputRef ref) {
       RelColumnOrigin columnOrigin = this.metadataProvider.getColumnOrigin(filterOp, ref.getIndex());
       if (columnOrigin != null) {
         RelOptHiveTable table = (RelOptHiveTable) columnOrigin.getOriginTable();
@@ -267,15 +298,24 @@ public RexNode visitCall(RexCall call) {
                   table.getColStat(Lists.newArrayList(columnOrigin.getOriginColumnOrdinal())).get(0);
           if (colStats != null && StatsSetupConst.areColumnStatsUptoDate(
                   table.getHiveTableMD().getParameters(), colStats.getColumnName())) {
-            Range range = colStats.getRange();
-            if (range != null) {
-              max = range.maxValue;
-              min = range.minValue;
-            }
+            return colStats;
+          }
+        }
+      }
+      return null;
+    }
+
+    private Long extractRowCount(RexInputRef ref) {
+      final RelColumnOrigin origin = this.metadataProvider.getColumnOrigin(filterOp, ref.getIndex());
+      if (origin != null) { // no column origin -> no row count can be attributed
+        final RelOptHiveTable hiveTable = (RelOptHiveTable) origin.getOriginTable();
+        if (hiveTable != null) {
+          if (StatsSetupConst.areBasicStatsUptoDate(hiveTable.getHiveTableMD().getParameters())) {
+            return StatsUtils.getNumRows(hiveTable.getHiveTableMD()); // trusted only when basic stats are up to date
           }
         }
       }
-      return Pair.<Number,Number>of(max, min);
+      return null; // origin/table unknown or basic stats stale
     }
 
     @SuppressWarnings("unchecked")
